{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%autoreload 2\n",
    "\n",
    "from IPython import display\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms, datasets\n",
    "\n",
    "from utils import Logger\n",
    "\n",
    "import tensorflow.compat.v1 as tf\n",
    "tf.disable_v2_behavior() \n",
    "from tensorflow import nn\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_FOLDER = './tf_data/DCGAN/CIFAR'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cifar_data():\n",
    "    compose = transforms.Compose(\n",
    "        [\n",
    "            transforms.Resize(64),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((.5,), (.5,)),\n",
    "        ])\n",
    "    out_dir = '{}/dataset'.format(DATA_FOLDER)\n",
    "    return datasets.CIFAR10(root=out_dir, train=True, transform=compose, download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = cifar_data()\n",
    "batch_size = 100\n",
    "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
    "num_batches = len(dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGES_SHAPE = (64, 64, 3)\n",
    "NOISE_SIZE = 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Ops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def default_conv2d(inputs, filters):\n",
    "    return layers.conv2d(\n",
    "        inputs,\n",
    "        filters=filters,\n",
    "        kernel_size=4,\n",
    "        strides=(2, 2),\n",
    "        padding='same',\n",
    "        data_format='channels_last',\n",
    "        use_bias=False,\n",
    "    )\n",
    "\n",
    "def default_conv2d_transpose(inputs, filters):\n",
    "    return layers.conv2d_transpose(\n",
    "        inputs,\n",
    "        filters=filters,\n",
    "        kernel_size=4,\n",
    "        strides=(2, 2),\n",
    "        padding='same',\n",
    "        data_format='channels_last',\n",
    "        use_bias=False,\n",
    "    )\n",
    "\n",
    "def noise(n_rows, n_cols):\n",
    "    return np.random.normal(size=(n_rows, n_cols))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def discriminator(x):\n",
    "    with tf.variable_scope(\"discriminator\", reuse=tf.AUTO_REUSE):\n",
    "        with tf.variable_scope(\"conv1\"):\n",
    "            conv1 = default_conv2d(x, 128)\n",
    "            conv1 = nn.leaky_relu(conv1,alpha=0.2)\n",
    "        \n",
    "        with tf.variable_scope(\"conv2\"):\n",
    "            conv2 = default_conv2d(conv1, 256)\n",
    "            conv2 = layers.batch_normalization(conv2)\n",
    "            conv2 = nn.leaky_relu(conv2,alpha=0.2)\n",
    "            \n",
    "        with tf.variable_scope(\"conv3\"):\n",
    "            conv3 = default_conv2d(conv2, 512)\n",
    "            conv3 = layers.batch_normalization(conv3)\n",
    "            conv3 = nn.leaky_relu(conv3,alpha=0.2)\n",
    "            \n",
    "        with tf.variable_scope(\"conv4\"):\n",
    "            conv4 = default_conv2d(conv3, 1024)\n",
    "            conv4 = layers.batch_normalization(conv4)\n",
    "            conv4 = nn.leaky_relu(conv4,alpha=0.2)\n",
    "        \n",
    "        with tf.variable_scope(\"linear\"):\n",
    "            linear = keras.layers.flatten(conv4)\n",
    "            linear = keras.layers.fully_connected(linear, 1)\n",
    "        \n",
    "        with tf.variable_scope(\"out\"):\n",
    "            out = nn.sigmoid(linear)\n",
    "    return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generator(z):\n",
    "    with tf.variable_scope(\"generator\", reuse=tf.AUTO_REUSE):\n",
    "        \n",
    "        with tf.variable_scope(\"linear\"):\n",
    "            linear = keras.layers.fully_connected(z, 1024 * 4 * 4)\n",
    "            \n",
    "        with tf.variable_scope(\"conv1_transp\"):\n",
    "            # Reshape as 4x4 images\n",
    "            conv1 = tf.reshape(linear, (-1, 4, 4, 1024))\n",
    "            conv1 = default_conv2d_transpose(conv1, 512)\n",
    "            conv1 = layers.batch_normalization(conv1)\n",
    "            conv1 = nn.relu(conv1)\n",
    "        \n",
    "        with tf.variable_scope(\"conv2_transp\"):\n",
    "            conv2 = default_conv2d_transpose(conv1, 256)\n",
    "            conv2 = layers.batch_normalization(conv2)\n",
    "            conv2 = nn.relu(conv2)\n",
    "            \n",
    "        with tf.variable_scope(\"conv3_transp\"):\n",
    "            conv3 = default_conv2d_transpose(conv2, 128)\n",
    "            conv3 = layers.batch_normalization(conv3)\n",
    "            conv3 = nn.relu(conv3)\n",
    "            \n",
    "        with tf.variable_scope(\"conv4_transp\"):\n",
    "            conv4 = default_conv2d_transpose(conv3, 3)\n",
    "        \n",
    "        with tf.variable_scope(\"out\"):\n",
    "            out = tf.tanh(conv4)\n",
    "    return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph Optimization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Real Input\n",
    "X = tf.placeholder(tf.float32, shape=(None, )+IMAGES_SHAPE)\n",
    "## Latent Variables / Noise\n",
    "Z = tf.placeholder(tf.float32, shape=(None, NOISE_SIZE))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Generator\n",
    "G_sample = generator(Z)\n",
    "# Discriminator\n",
    "D_real = discriminator(X)\n",
    "D_fake = discriminator(G_sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Generator\n",
    "G_loss = tf.reduce_mean(\n",
    "    tf.nn.sigmoid_cross_entropy_with_logits(\n",
    "        logits=D_fake, labels=tf.ones_like(D_fake)\n",
    "    )\n",
    ")\n",
    "\n",
    "# Discriminator\n",
    "D_loss_real = tf.reduce_mean(\n",
    "    tf.nn.sigmoid_cross_entropy_with_logits(\n",
    "        logits=D_real, labels=tf.ones_like(D_real)\n",
    "    )\n",
    ")\n",
    "\n",
    "D_loss_fake = tf.reduce_mean(\n",
    "    tf.nn.sigmoid_cross_entropy_with_logits(\n",
    "        logits=D_fake, labels=tf.zeros_like(D_fake)\n",
    "    )\n",
    ")\n",
    "\n",
    "D_loss = D_loss_real + D_loss_fake"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Optimizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain trainable variables for both networks\n",
    "train_vars = tf.trainable_variables()\n",
    "\n",
    "G_vars = [var for var in train_vars if 'generator' in var.name]\n",
    "D_vars = [var for var in train_vars if 'discriminator' in var.name]\n",
    "\n",
    "num_epochs = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "G_opt = tf.train.AdamOptimizer(2e-4).minimize(G_loss, var_list=G_vars,)\n",
    "D_opt = tf.train.AdamOptimizer(2e-4).minimize(D_loss, var_list=D_vars,)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_test_samples = 16\n",
    "test_noise = noise(num_test_samples, NOISE_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Inits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 100\n",
    "NUM_EPOCHS = 200\n",
    "\n",
    "# Start interactive session\n",
    "session = tf.InteractiveSession()\n",
    "# Init Variables\n",
    "tf.global_variables_initializer().run()\n",
    "# Init Logger\n",
    "logger = Logger(model_name='DCGAN1', data_name='CIFAR10')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Start interactive session\n",
    "session = tf.InteractiveSession()\n",
    "# Init Variables\n",
    "tf.global_variables_initializer().run()\n",
    "\n",
    "# Iterate through epochs\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    for n_batch, (batch,_) in enumerate(dataloader):\n",
    "        \n",
    "        # 1. Train Discriminator\n",
    "        X_batch = batch.permute(0, 2, 3, 1).numpy()\n",
    "        feed_dict = {X: X_batch, Z: noise(BATCH_SIZE, NOISE_SIZE)}\n",
    "        _, d_error, d_pred_real, d_pred_fake = session.run(\n",
    "            [D_opt, D_loss, D_real, D_fake], feed_dict=feed_dict\n",
    "        )\n",
    "\n",
    "        # 2. Train Generator\n",
    "        feed_dict = {Z: noise(BATCH_SIZE, NOISE_SIZE)}\n",
    "        _, g_error = session.run(\n",
    "            [G_opt, G_loss], feed_dict=feed_dict\n",
    "        )\n",
    "        \n",
    "        if n_batch % 100 == 0:\n",
    "            display.clear_output(True)\n",
    "            # Generate images from test noise\n",
    "            test_images = session.run(\n",
    "                G_sample, feed_dict={Z: test_noise}\n",
    "            )\n",
    "            # Log Images\n",
    "            logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches, format='NHWC');\n",
    "            # Log Status\n",
    "            logger.display_status(\n",
    "                epoch, num_epochs, n_batch, num_batches,\n",
    "                d_error, g_error, d_pred_real, d_pred_fake\n",
    "            )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
